{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Knet RNN example\n",
    "**TODO**: Use the new RNN interface, add dropout?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "true"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Knet\n",
    "# The shared hyperparameter file is Python source; defining `True` lets Julia `include` it unchanged.\n",
    "# NOTE(review): presumably this file defines EMBEDSIZE, NUMHIDDEN, MAXFEATURES, MAXLEN, BATCHSIZE, EPOCHS, LR, BETA_1, BETA_2 and EPS used below -- confirm.\n",
    "True=true # so we can read the python params\n",
    "include(\"common/params_lstm.py\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS: Linux\n",
      "Julia: 1.5.0\n",
      "Knet: 1.4.0\n",
      "GPU: Quadro M2000\n",
      "Tesla P4\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Report the software/hardware environment that produced the timings in this notebook.\n",
    "import Pkg # bring Pkg into scope explicitly; `using Knet` alone does not provide it on a fresh kernel\n",
    "println(\"OS: \", Sys.KERNEL)\n",
    "println(\"Julia: \", VERSION)\n",
    "# The installed package version is looked up by UUID in the active environment.\n",
    "println(\"Knet: \", Pkg.dependencies()[Base.UUID(\"1902f260-5fb4-5aff-8c31-6271790ab950\")].version)\n",
    "# One GPU name per line, as printed by nvidia-smi (requires nvidia-smi on the PATH).\n",
    "println(\"GPU: \", read(`nvidia-smi --query-gpu=name --format=csv,noheader`,String))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model definition. NOTE: this uses an outdated interface; see tutorial/70.imdb.ipynb for a more up-to-date implementation.\n",
    "# Returns the tuple (rnn, inputMatrix, outputMatrix) that `predict` destructures.\n",
    "function initmodel()\n",
    "    return (\n",
    "        RNN(EMBEDSIZE, NUMHIDDEN; rnnType = :gru, atype = Knet.array_type[]),  # GRU layer: EMBEDSIZE inputs, NUMHIDDEN hidden units\n",
    "        param(EMBEDSIZE, MAXFEATURES),  # input matrix (EMBEDSIZE x MAXFEATURES), indexed by column in predict\n",
    "        param(2, NUMHIDDEN),            # output matrix (2 x NUMHIDDEN), maps a hidden state to 2 scores\n",
    "    )\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "loss (generic function with 1 method)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# define the prediction function and the loss (no gradient is defined explicitly in this cell)\n",
    "# Score a minibatch: `inputs` is a collection of B index sequences; the result is a (2,B) score matrix.\n",
    "function predict(model, inputs)\n",
    "    gru, embedding, classifier = model       # recurrent layer, (X,V) input matrix, (2,H) output matrix\n",
    "    tokens = permutedims(hcat(inputs...))    # one sequence per row: (B,T); hcat needs equal-length sequences\n",
    "    embedded = embedding[:, tokens]          # column lookup for every token: (X,B,T)\n",
    "    states = gru(embedded)                   # (H,B,T)\n",
    "    finalState = states[:, :, end]           # keep the last time step only: (H,B)\n",
    "    return classifier * finalState           # (2,H) * (H,B) = (2,B)\n",
    "end\n",
    "\n",
    "# Loss for one minibatch: Knet's `nll` applied to the predicted scores for x and the labels y.\n",
    "function loss(model, x, y)\n",
    "    scores = predict(model, x)\n",
    "    return nll(scores, y)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading IMDB...\n",
      "└ @ Main /userfiles/dyuret/.julia/dev/Knet/data/imdb.jl:57\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  9.476542 seconds (26.67 M allocations: 1.390 GiB, 8.57% gc time)\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n"
     ]
    }
   ],
   "source": [
    "# Load the IMDB data using the loader script found under Knet's data directory.\n",
    "include(Knet.dir(\"data\", \"imdb.jl\"))\n",
    "@time (xtrn, ytrn, xtst, ytst, imdbdict) = imdb(maxlen = MAXLEN, maxval = MAXFEATURES)\n",
    "# Print a one-line summary of each of the four data arrays.\n",
    "foreach(part -> println(summary(part)), (xtrn, ytrn, xtst, ytst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150-element Array{String,1}:\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " \"reilly\"\n",
       " ⋮\n",
       " \"tommy\"\n",
       " \"davidson\"\n",
       " \"and\"\n",
       " \"damon\"\n",
       " \"wayans\"\n",
       " \"moving\"\n",
       " \"performances\"\n",
       " \"in\"\n",
       " \"spike\"\n",
       " \"lee's\"\n",
       " \"satire\"\n",
       " \"bamboozled\""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Invert the word => index dictionary into an index => word lookup table.\n",
    "# The table is sized from the largest index in the dictionary instead of a hard-coded vocabulary size.\n",
    "imdbarray = Array{String}(undef, maximum(values(imdbdict)))\n",
    "for (word, index) in imdbdict; imdbarray[index] = word; end\n",
    "# Decode the first training example back into words as a sanity check.\n",
    "imdbarray[xtrn[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "390-element Knet.Train20.Data{Tuple{Array{Array{Int32,1},1},Array{Int8,1}}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set up for the cold-start pass: fresh model plus shuffled training minibatches.\n",
    "model = nothing   # drop the reference to any model left over from a previous run ...\n",
    "GC.gc(true)       # ... so a full collection can reclaim its memory\n",
    "model = initmodel()\n",
    "dtrn = minibatch(xtrn, ytrn, BATCHSIZE; shuffle = true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 20.553846 seconds (27.02 M allocations: 1.438 GiB, 2.18% gc time)\n"
     ]
    }
   ],
   "source": [
    "# Cold start: one timed pass over dtrn before the measured epochs (presumably absorbs first-call compilation cost).\n",
    "@time adam!(\n",
    "    (x, y) -> loss(model, x, y),\n",
    "    dtrn;\n",
    "    lr = LR, beta1 = BETA_1, beta2 = BETA_2, eps = EPS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(GRU(input=125,hidden=100), P(KnetArray{Float32,2}(125,30000)), P(KnetArray{Float32,2}(2,100)))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discard the model trained during the cold start so the timed run begins from a new model.\n",
    "model = nothing   # drop the reference to the previous model ...\n",
    "GC.gc(true)       # ... so a full collection can reclaim its memory\n",
    "model = initmodel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Training...\n",
      "└ @ Main In[10]:2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  4.760009 seconds (686.56 k allocations: 90.851 MiB, 0.85% gc time)\n",
      "  4.225123 seconds (683.99 k allocations: 90.816 MiB, 0.64% gc time)\n",
      "  4.236897 seconds (684.67 k allocations: 90.788 MiB, 0.85% gc time)\n",
      " 13.233227 seconds (2.07 M allocations: 272.768 MiB, 0.83% gc time)\n"
     ]
    }
   ],
   "source": [
    "# Timed training: the outer @time covers all EPOCHS passes, the inner @time reports each pass.\n",
    "@info(\"Training...\")\n",
    "@time for _ in 1:EPOCHS\n",
    "    @time adam!(\n",
    "        (x, y) -> loss(model, x, y),\n",
    "        dtrn;\n",
    "        lr = LR, beta1 = BETA_1, beta2 = BETA_2, eps = EPS,\n",
    "    )\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Testing...\n",
      "└ @ Main In[11]:1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  2.739723 seconds (3.90 M allocations: 250.743 MiB, 2.82% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.854326923076923"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@info(\"Testing...\")\n",
    "# Accuracy of the trained model on the test split; the timing includes building the test minibatches.\n",
    "@time accuracy(\n",
    "    x -> predict(model, x);\n",
    "    data = minibatch(xtst, ytst, BATCHSIZE),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
